package com.loginfail_detect;

import com.loginfail_detect.bean.LoginEvent;
import com.loginfail_detect.bean.LoginFailWarning;
import org.apache.commons.compress.utils.Lists;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.TimeCharacteristic;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.streaming.api.functions.timestamps.BoundedOutOfOrdernessTimestampExtractor;
import org.apache.flink.streaming.api.windowing.time.Time;
import org.apache.flink.util.Collector;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * @Description: TODO QQ1667847363
 * @author: xiao kun tai
 * @date:2021/11/11 17:50
 *
 * 恶意登录监控
 * 2秒内登录次数不超过n次（onTimer）
 *
 * 2秒内登录次数不超过2次（没有使用定时器）（有缺陷）
 */
public class LoginFail {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(1);
        env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);

        //从文件中读取数据
        String filePath = "LoginFailDetect/src/main/resources/LoginLog.csv";
        DataStream<String> fileStream = env.readTextFile(filePath);


        DataStream<LoginEvent> loginEventStream = fileStream
                .map(line -> {
                    String[] fields = line.split(",");
                    return new LoginEvent(new Long(fields[0]), fields[1], fields[2], new Long(fields[3]));
                })
                .assignTimestampsAndWatermarks(new BoundedOutOfOrdernessTimestampExtractor<LoginEvent>(Time.seconds(3)) {
                    @Override
                    public long extractTimestamp(LoginEvent loginEvent) {
                        return loginEvent.getTimestamp() * 1000L;
                    }
                });

        // 自定义处理函数检测连续登录失败事件
        SingleOutputStreamOperator<LoginFailWarning> warnningStream = loginEventStream
                .keyBy(LoginEvent::getUserId)
                .process(new LoginFailDetectWarning0(2));

        warnningStream.print();


        env.execute("login fail detect job");
    }

    /**
     * 实现自定义KeyedProcessFunction
     */
    public static class LoginFailDetectWarning0 extends KeyedProcessFunction<Long, LoginEvent, LoginFailWarning> {
        //定一个属性，最大连续登录失败次数
        private Integer maxFailTimes;

        public LoginFailDetectWarning0(Integer maxFailTimes) {
            this.maxFailTimes = maxFailTimes;
        }

        //定义状态：保存2秒内所有的登录失败事件
        ListState<LoginEvent> loginEventListState;

        //定义状态：保存注册的定时器时间戳
        ValueState<Long> timerTsState;

        @Override
        public void open(Configuration parameters) throws Exception {
            loginEventListState = getRuntimeContext()
                    .getListState(new ListStateDescriptor<LoginEvent>("login-fail-list",
                            LoginEvent.class));
            timerTsState = getRuntimeContext()
                    .getState(new ValueStateDescriptor<Long>("timer-ts",
                            Long.class));
        }

        @Override
        public void processElement(LoginEvent loginEvent, Context context, Collector<LoginFailWarning> collector) throws Exception {

            //判断当前登录事件类型
            if ("fail".equals(loginEvent.getLoginState())) {
                //如果是失败事件，添加到列表状态中
                loginEventListState.add(loginEvent);
                //如果没有定时器，注册一个2秒之后的定时器
                if (timerTsState.value() == null) {
                    Long ts = (loginEvent.getTimestamp() + 2) * 1000L;
                    context.timerService().registerEventTimeTimer(ts);
                    timerTsState.update(ts);
                }
            } else {
                //如果是登录成功，删除定时器，清空状态，重新开始
                if (timerTsState.value() != null)
                    context.timerService().deleteEventTimeTimer(timerTsState.value());
                loginEventListState.clear();
                timerTsState.clear();
            }

        }

        @Override
        public void onTimer(long timestamp, OnTimerContext ctx, Collector<LoginFailWarning> out) throws Exception {
            //定时器触发，说明2秒内没有登录成功，判断ListState中的失败的个数
            ArrayList<LoginEvent> loginFailEvents = Lists.newArrayList(loginEventListState.get().iterator());
            Integer failTimes = loginFailEvents.size();
            if (failTimes >= maxFailTimes) {
                //如果超出设定的最大失败的次数，输出报警
                out.collect(new LoginFailWarning(ctx.getCurrentKey(),
                        loginFailEvents.get(0).getTimestamp(),
                        loginFailEvents.get(failTimes - 1).getTimestamp(),
                        "login fail in 2s for " + failTimes + " times"));
            }

            //清空状态
            loginEventListState.clear();
            timerTsState.clear();

        }
    }


    /**
     * 优化
     */
    public static class LoginFailDetectWarning extends KeyedProcessFunction<Long, LoginEvent, LoginFailWarning> {
        //定一个属性，最大连续登录失败次数
        private Integer maxFailTimes;

        public LoginFailDetectWarning(Integer maxFailTimes) {
            this.maxFailTimes = maxFailTimes;
        }

        //定义状态：保存2秒内所有的登录失败事件
        ListState<LoginEvent> loginEventListState;


        @Override
        public void open(Configuration parameters) throws Exception {
            loginEventListState = getRuntimeContext()
                    .getListState(new ListStateDescriptor<LoginEvent>("login-fail-list",
                            LoginEvent.class));

        }

        // 以登录事件作为判断报警的触发条件，不再注册定时器。
        @Override
        public void processElement(LoginEvent loginEvent, Context context, Collector<LoginFailWarning> collector) throws Exception {
            //判断当前事件登录状态
            if ("fail".equals(loginEvent.getLoginState())) {
                //如果是登录失败，获取状态中之前的登录失败事件，继续判断是否已有失败事件
                Iterator<LoginEvent> iterator = loginEventListState.get().iterator();
                if (iterator.hasNext()) {
                    //如果已经有登录失败事件，继续判断时间戳是否在2秒之内
                    //获取已有的登录失败事件
                    LoginEvent firstFailEvent = iterator.next();
                    if (loginEvent.getTimestamp() - firstFailEvent.getTimestamp() <= 2){
                        //如果在2秒之内，输出报警
                        collector.collect(new LoginFailWarning(loginEvent.getUserId(),
                                firstFailEvent.getTimestamp(),
                                loginEvent.getTimestamp(),
                                "login fail 2 times in 2s"));
                    }

                    //不管报不报警，这次都已经处理，直接更新状态
                    loginEventListState.clear();
                    loginEventListState.add(loginEvent);


                } else {
                    //如果没有登录失败记录，直接将当前事件存入ListState
                    loginEventListState.add(loginEvent);
                }

            } else {
                //如果是登录成功，直接清空状态
                loginEventListState.clear();
            }


        }

    }
}
